# Copyright 2013 Al Cramer
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import sys

"""
Script for generating semantic indices. It reads an xml version of
roget as input ("roget.xml"), and writes n ascii file of synonym lists
("synlists.txt") as output. Additional input files are:
1. "mobythesuarus.txt" -- the Moby Thesaurus.
2. "wrdcounts.txt" -- ascii giving word counts for the 5,000 most common
  English words, as anlayzed over a very large corpus.

The output file generated by this script is "synlists.xml".
Various intermediary files are also created. You need to manually
delete them.
"""

# part of speech bitmasks, OR-ed together in Wrd.props
PS_N = 0x1  # noun
PS_V = 0x2  # verb
PS_Mod = 0x4  # modifier (adjective or adverb)
PS_Ignore = 0x8  # disqualified: skipped when building synonym lists

# Dictionary entry
class Wrd:
    """A vocabulary entry: spelling, corpus count, and a PS_* bitmask."""

    def __init__(self, sp):
        self.sp = sp     # spelling
        self.cnt = 0     # corpus frequency count
        self.props = 0   # OR-ed PS_* part-of-speech bits

    def checkProp(self, m):
        """Return True if any bit of mask "m" is set in props."""
        return bool(self.props & m)

    def printWrd(self, fp):
        """Write a one-line summary of this entry to file "fp"."""
        labels = []
        for mask, label in ((PS_N, 'N'), (PS_V, 'V'),
                            (PS_Mod, 'Mod'), (PS_Ignore, 'Ignore')):
            if self.checkProp(mask):
                labels.append(label)
        fp.write('%s %d %s\n' % (self.sp, self.cnt, '|'.join(labels)))
    
# read "wordcnts.txt" and create the lexicon
lexicon = {}
def getLexicon():
    fp = open('wordcnts.txt','r')
    lines = fp.readlines()
    fp.close()
    for li in lines:
        terms = li.split()
        assert len(terms)==4
        cnt = int(terms[1].strip())
        sp = terms[2].strip()
        spPos = terms[3].strip()
        v = PS_Ignore
        if spPos == 'n':
            v = PS_N
        elif spPos == 'v':
            v = PS_V
        elif spPos == 'a' or spPos == 'adv':
            v = PS_Mod
        e = lexicon.get(sp)
        if e is None:
            e = Wrd(sp)
            lexicon[sp] = e
        e.props |= v
        e.cnt += cnt
    # special cases we want to disqualify
    ignore = ['be',"one's",'not','do','she-']
    for sp in ignore:
        e = lexicon.get(sp)
        if e is None:
            e = lexicon[sp] = Wrd(sp)
        e.props = PS_Ignore

# print stats re the lexicon
def printLexiconStats():
    """
    Print noun/verb/modifier counts for the global "lexicon" to stdout.
    Each word is counted once, under its highest-priority part of speech
    (verb, then noun, then modifier).
    """
    nV = 0
    nN = 0
    nMod = 0
    # keys are unused; iterate the Wrd values directly
    for w in lexicon.values():
        if (w.props & PS_V) != 0:
            nV += 1
        elif (w.props & PS_N) != 0:
            nN += 1
        elif (w.props & PS_Mod) != 0:
            nMod += 1
    print('lexicon. nWords: %d' % (nV+nN+nMod))
    print('  verbs: %d' % nV)
    print('  nouns: %d' % nN)
    print('  modifiers: %d' % nMod)

    
# print a dictionary of words
def printDct(fp, dct):
    """
    Write the entries of "dct" (a mapping spelling -> Wrd) to file "fp",
    most frequent first.

    Bug fix: the original list-sorted [cnt, Wrd] pairs, which falls back
    to comparing the Wrd objects themselves on count ties (arbitrary
    ordering, and an error on Python 3). Sorting on the count alone is
    stable and tie-safe.
    """
    for e in sorted(dct.values(), key=lambda e: e.cnt, reverse=True):
        e.printWrd(fp)

# is "w" a word we should ignore?
def ignoreWord(w):
    e = lexicon.get(w)
    return e is not None and e.checkProp(PS_Ignore)

# Roget Data Model
# a Group is an id and semi-colon delimited set of synonym lists
class Group:
    """A named group of synonym lists."""
    def __init__(self, name, synlsts):
        # group id, e.g. "812.V.1"
        self.name = name
        # ';'-delimited synonym lists; each list is ','-delimited terms
        self.synlsts = synlsts

# a section is an id and a collection of groups
class Section:
    """A named, ordered section of the thesaurus."""
    def __init__(self, name, order):
        self.name = name    # section id
        self.order = order  # 1-based position in the source file
        self.groups = []    # Group objects belonging to this section
        self.dct = {}       # section-specific lexicon: spelling -> Wrd

# Data model is a collection of sections
class Dom:
    """The Roget data model: a collection of named Sections."""

    def __init__(self):
        # section-name -> Section
        self.sections = {}

    def writeDom(self, fn):
        """Write the model to ascii file "fn" in '+s'/'+g' line format."""
        # emit sections in their original (order attribute) order
        ordered = sorted(self.sections.values(), key=lambda s: s.order)
        with open(fn, 'w') as fp:
            for s in ordered:
                fp.write('+s %s\n' % s.name)
                for g in s.groups:
                    fp.write('+g %s\n' % g.name)
                    fp.write('%s\n' % g.synlsts)

    def readDom(self, fn):
        """Read the model from ascii file "fn" (the writeDom format)."""
        # bug fix: the original never closed the file handle
        with open(fn, 'r') as fp:
            curSect = None
            sectOrder = 1
            while True:
                li = fp.readline()
                if not li:
                    break
                li = li.strip()
                if li.startswith('+s'):
                    name = li.split()[1].strip()
                    curSect = self.sections[name] = Section(name, sectOrder)
                    sectOrder += 1
                    continue
                if li.startswith('+g'):
                    grpName = li.split()[1].strip()
                    # the group's synonym lists are on the next line
                    li = fp.readline().strip()
                    curSect.groups.append(Group(grpName, li))
                    continue

    def resolveXref(self, sectId, keyword):
        """
        Resolve a cross reference: return a ','-joined string of all
        synonym lists in section "sectId" that contain "keyword"
        (substring match), with any embedded '#...' xref markers blanked.
        Returns '' if the section is unknown.
        """
        sect = self.sections.get(sectId)
        if sect is None:
            return ''
        resolved = []
        # walk thru the groups in the section
        for g in sect.groups:
            # split g into synlst's
            for synlst in g.synlsts.split(';'):
                if synlst.find(keyword) != -1:
                    synlst = re.sub(r'#[\d\w:\.]*', ' ', synlst)
                    resolved.append(synlst)
        return ','.join(resolved)

# helper for "makeRogetExp": expand xref's in a synlist
def expandSynLst(dom, raw):
    """
    Return "raw" with each embedded cross reference expanded in place.

    A cross reference has the form '#<sectName>:<keyword>' and runs up
    to the next space, comma, or semicolon. A resolvable xref is
    replaced by ',<resolved synlists>,'; a malformed or unresolvable
    one by a single space.
    """
    cooked = raw
    while True:
        S = cooked.find('#')
        if S == -1:
            break
        # advance E to the last character of the xref token
        E = S
        while E+1 < len(cooked) and cooked[E+1] not in (' ', ',', ';'):
            E += 1
        lhs = cooked[:S]
        rhs = cooked[E+1:]
        repl = ' '
        xref = cooked[S:E+1]
        xrefSplit = xref.split(':')
        if len(xrefSplit) == 2:
            # strip the leading '#' from the section name
            sectName = xrefSplit[0][1:]
            keyword = xrefSplit[1]
            resolved = dom.resolveXref(sectName, keyword)
            if len(resolved) > 0:
                # bug fix: reuse "resolved" instead of calling
                # dom.resolveXref a second time with the same arguments
                repl = ',' + resolved + ','
        cooked = lhs + repl + rhs
    return cooked
        
# create "rogetexp.txt": roget, with the cross ref's expanded
def makeRogetExp():
    fin = open('roget.xml','r')
    fout = open('tmp.txt','w')
    while True:
        li = fin.readline()
        if not li:
            break
        if len(li.strip())==0:
            continue
        if li.startswith('<section'):
            m = re.search(r'"([^"]*)"',li)
            fout.write('+s %s\n' % m.group(1))
            continue
        if li.startswith('<g'):
            m = re.search(r'"([^"]*)"',li)
            fout.write('+g %s\n' % m.group(1))
            sense = []
            while True:
                li = fin.readline()
                li = li.strip()
                if li.endswith(';'):
                    li = li[:-1]
                if len(li)==0:
                    continue
                if li.startswith('<'):
                    break
                sense.append(li)
            fout.write('%s\n' % ';'.join(sense))
    fin.close()
    fout.close()
    # read in as a dom
    dom = Dom()
    dom.readDom('tmp.txt')
    # dev code: write out
    # dom.writeDom('tmp1.txt')
    # create a new dom, "dom1": it's dom with the cross ref's
    # expanded.
    dom1 = Dom()
    for sectId,sect in dom.sections.iteritems():
        # print sect.name
        sect1 = Section(sect.name,sect.order)
        dom1.sections[sect.name] = sect1
        # walk the groups of "sect", adding groups to sect1
        for g in sect.groups:
##            print "group:%s" % g.name
##            if g.name == '812.V.1':
##                debug = 1
            synlsts1 = []
            for synlst in g.synlsts.split(';'):
                synlst = synlst.strip()
                synlsts1.append(expandSynLst(dom,synlst))
            sl = ';'.join(synlsts1)
            sl = re.sub(r'[ ]*,',',',sl)
            sl = re.sub(r'[,]+,',',',sl)
            sl = re.sub(r';,',';',sl)
            sl = re.sub(r',;',';',sl)
            g1 = Group(g.name,sl)
            sect1.groups.append(g1)
    dom1.writeDom('rogetexp.txt')    

# This lexicon covers all sections: spelling -> Wrd
rogetLexicon = {}

def _countWrd(dct, w):
    """Bump the count for "w" in "dct", creating the Wrd entry on demand."""
    e = dct.get(w)
    if e is None:
        e = dct[w] = Wrd(w)
    e.cnt += 1

# create roget and section-specific lexicons
def makeSectionLexicons(dom):
    """
    Walk every synonym list in "dom", counting each single-word term
    both in its section's lexicon (Section.dct) and in the global
    "rogetLexicon". Phrases and ignorable words are skipped.
    """
    for s in dom.sections.values():
        for g in s.groups:
            for sl in g.synlsts.split(';'):
                sl = sl.strip()
                if len(sl) == 0:
                    continue
                for w in sl.split(','):
                    w = w.strip()
                    # skip phrases
                    if len(w) == 0 or w.count(' ') > 0:
                        continue
                    if ignoreWord(w):
                        continue
                    _countWrd(s.dct, w)
                    _countWrd(rogetLexicon, w)

# get keyword for a group (a list of synlsts; each synlst is
# represented as a set of terms)
def getGroupKey(dct, synlsts):
    """
    Choose the most significant word for a group, or None if the group
    yields no usable words.

    The winner is the word occurring most often across all terms; ties
    go first to the more frequent word in lexicon "dct", then to the
    lexicographically smaller word.
    """
    # count occurrences of each word across the group
    cnts = {}
    for synlst in synlsts:
        for term in synlst:
            for w in term.split(' '):
                if ignoreWord(w):
                    continue
                cnts[w] = cnts.get(w, 0) + 1
    # find most significant wrd
    best = None
    for w, cnt in cnts.items():
        if best is None:
            best = w
            continue
        if cnts[w] > cnts[best]:
            best = w
            continue
        if cnts[w] == cnts[best]:
            # a tie. Go for more frequent word in dictionary.
            if dct.get(w) is not None and \
                dct.get(best) is not None:
                if dct[w].cnt > dct[best].cnt:
                    best = w
                    continue
            # go with lexicographic ordering
            if w < best:
                best = w
    return best

# get keyword for a synonym set
def getSynsetKey(dct, aset):
    """
    Choose a keyword for the synonym set "aset", or None if every
    member is a compound (multi-word) term.

    Prefers the word with the higher count in lexicon "dct"; when
    counts are unavailable or equal, the lexicographically smaller
    word wins.
    """
    best = None
    for w in aset:
        if w.find(' ') != -1:
            # reject compound words as keys
            continue
        if best is None:
            best = w
            continue
        if dct.get(best) is not None and \
            dct.get(w) is not None:
            # bug fix: the original compared dct[w].cnt against the Wrd
            # object dct[best] itself instead of its .cnt, silently
            # disabling the frequency tie-break
            if dct[w].cnt > dct[best].cnt:
                best = w
                continue
            if dct[w].cnt < dct[best].cnt:
                # leave as is
                continue
        # go with lexicographic ordering
        if w < best:
            best = w
    return best
                
def printSectionLexicons(fn, dom):
    """
    Write each section's lexicon to file "fn", sections in their
    original order, each preceded by a '-->order. name' header.
    """
    with open(fn, "w") as fp:
        for s in sorted(dom.sections.values(), key=lambda s: s.order):
            fp.write("-->%d. %s\n" % (s.order, s.name))
            printDct(fp, s.dct)
    
# Our basic data structure is the mapping str->{strs},
# aka. strToSetOfStrs. One example is "keyToWrds", which
# maps a key to the sets of words which are mapped to that key.
# A second example is "wrdToKeys", which maps a word to its key.
# Note that wrdToKeys is the inverse of keyToWrds.
class StrToSetOfStrs:
    """A mapping from strings to sets of strings."""

    def __init__(self):
        # key -> set of strings
        self.dct = {}

    def addMapping(self, key, element):
        """Add "element" to the set mapped to "key" (creating it if new)."""
        entry = self.dct.get(key)
        if entry is None:
            entry = self.dct[key] = set()
        entry.add(element)

    def invert(self):
        """Return a new StrToSetOfStrs mapping each element to its keys."""
        inverted = StrToSetOfStrs()
        for key, aset in self.dct.items():
            for e in aset:
                inverted.addMapping(e, key)
        return inverted

    # validate: remove mappings to null-set
    def validate(self):
        """Drop every key whose set is empty."""
        dctPrime = {}
        for key, aset in self.dct.items():
            if len(aset) > 0:
                dctPrime[key] = aset
        self.dct = dctPrime

    # is set1 a "fuzzy" subset of set2? This just means:
    # are most of set1's elements contained in set2.
    def isFuzzySubset(self, set1, set2):
        """True if at least 75% of set1's elements appear in set2."""
        if len(set1) == 0 or len(set2) == 0:
            return False
        lenIsect = float(len(set1.intersection(set2)))
        lenSet1 = float(len(set1))
        return lenIsect/lenSet1 >= .75

    # Consider the mappings key1->{set1}, and key2->{set2}.
    # If set1 is a subset of set2, we say the mapping
    # key2->{set2} is a super-mapping for key1->{set1}.
    # This method finds all super-mappings in the collection.
    # It returns the dictionary key->{keyx}, where "key" is the
    # key for a mapping and {keyx} gives the keys of its super
    # mappings. If some given key_i has no entry in the dictionary,
    # then key_i has no super-mapping.
    # "useFuzzyTest" means: use "isFuzzySubset" to determine whether
    # setA is a subset of setB.
    def getSuperMappings(self, useFuzzyTest=False):
        # "startWrds": set of lexicographically smallest words
        # that start sets.
        startWrds = set()
        for key, aset in self.dct.items():
            alst = list(aset)
            alst.sort()
            if len(alst) > 0:
                startWrds.add(alst[0])
        # "containsWrd": mapping, startWrd->{key}, giving the keys of
        # mappings which contain "startWrd".
        containsWrd = {}
        for key, aset in self.dct.items():
            for w in aset:
                if w in startWrds:
                    if containsWrd.get(w) is None:
                        containsWrd[w] = set()
                    containsWrd[w].add(key)
        # mapping, key->{super-keys}
        keyToSuperKeys = {}
        for key, aset in self.dct.items():
            alst = list(aset)
            if len(alst) == 0:
                continue
            alst.sort()
            w = alst[0]
            # only mappings that contain aset's smallest word "w" need
            # be examined as superset candidates
            for keyx in containsWrd[w]:
                if keyx == key:
                    # this is us! skip
                    continue
                setx = self.dct[keyx]
                if useFuzzyTest:
                    asetIsSubset = self.isFuzzySubset(aset, setx)
                else:
                    asetIsSubset = aset.issubset(setx)
                if asetIsSubset:
                    # "keyx" is a super-mapping for key
                    if keyToSuperKeys.get(key) is None:
                        keyToSuperKeys[key] = set()
                    keyToSuperKeys[key].add(keyx)
        return keyToSuperKeys

    # kill redundant mappings. Suppose keyx is a super-mapping
    # for key, and both share the same synset keyword. Then
    # the mapping "key" is redundant.
    def killRedundantMappings(self):
        keyToSuperKeys = self.getSuperMappings()
        for key, superkeys in keyToSuperKeys.items():
            for keyx in superkeys:
                aset = self.dct[key]
                setx = self.dct[keyx]
                if getSynsetKey(rogetLexicon, aset) == \
                    getSynsetKey(rogetLexicon, setx):
                    # empty the redundant set; validate() removes it
                    self.dct[key] = set()
                    break
        self.validate()

    def refactor(self):
        """
        Shrink fuzzy supersets: remove a subset's members from each of
        its supersets, then re-add the subset's core-vocabulary words.
        """
        self.killRedundantMappings()
        nModified = 0
        keyToSuperKeys = self.getSuperMappings(True)
        for key, superkeys in keyToSuperKeys.items():
            aset = self.dct[key]
            if len(aset) < 4:
                continue
            # the subset's words that belong to the core vocabulary
            coreWrds = set()
            for w in aset:
                if lexicon.get(w) is None:
                    continue
                if not ignoreWord(w):
                    coreWrds.add(w)
            for keyx in superkeys:
                setx = self.dct[keyx]
                if not self.isFuzzySubset(aset, setx):
                    continue
                setx = setx.difference(aset)
                # with fuzzy subsets, it's possible for setx to be empty
                if len(setx) > 0:
                    setx = setx.union(coreWrds)
                self.dct[keyx] = setx
                nModified += 1
        self.validate()
        print('refactor. Changed %d sets. nSynsets: %d' %
              (nModified, len(self.dct)))

    def mergeSimilar(self):
        """Merge pairs of mappings whose sets are mutual fuzzy subsets."""
        nModified = 0
        keyToSuperKeys = self.getSuperMappings(True)
        for key, superkeys in keyToSuperKeys.items():
            aset = self.dct[key]
            if len(aset) < 4:
                continue
            for keyx in superkeys:
                setx = self.dct[keyx]
                if not (self.isFuzzySubset(aset, setx) and
                        self.isFuzzySubset(setx, aset)):
                    continue
                # reset aset as union of aset and setx; kill setx
                self.dct[key] = aset.union(setx)
                self.dct[keyx] = set()
                nModified += 1
        self.validate()
        print('\nmergeSimilar modified %d sets. nSynsets: %d' %
              (nModified, len(self.dct)))

    def killSubsets(self):
        """Fold each fuzzy subset into one of its supersets, then drop it."""
        nModified = 0
        keyToSuperKeys = self.getSuperMappings(True)
        for key, superkeys in keyToSuperKeys.items():
            aset = self.dct[key]
            if len(aset) == 0:
                continue
            for keyx in superkeys:
                setx = self.dct[keyx]
                if self.isFuzzySubset(aset, setx):
                    self.dct[keyx] = setx.union(aset)
                    self.dct[key] = set()
                    nModified += 1
                    break
        self.validate()
        print('\nkillSubsets killed %d sets. nSynsets: %d' %
              (nModified, len(self.dct)))

    def printMapping(self, fn):
        """
        Write the mapping to file "fn" (stdout if None), sorted by key.
        Returns the number of entries written.
        """
        tmp = []
        for key, aset in self.dct.items():
            lst = list(aset)
            lst.sort()
            tmp.append([key, ','.join(lst)])
        tmp.sort()
        if fn is None:
            fp = sys.stdout
        else:
            fp = open(fn, "w")
        for (key, lst) in tmp:
            fp.write('%s:\n%s\n\n' % (key, lst))
        if fp != sys.stdout:
            fp.close()
        return len(tmp)

    # unit test
    @classmethod
    def ut(cls):
        map1 = StrToSetOfStrs()
        map1.addMapping('fruit.fruit', 'apple')
        map1.addMapping('fruit.fruit', 'orange')
        map1.addMapping('fruit.fruit', 'tomato')
        map1.addMapping('vegetable.vegetable', 'squash')
        map1.addMapping('vegetable.vegetable', 'tomato')
        map1.printMapping(None)
        # test invert
        map2 = map1.invert()
        map2.printMapping(None)
        # test super-mapping
        map1.addMapping('fruit.red', 'tomato')
        map1.addMapping('fruit.red', 'apple')
        map1.addMapping('vegetable.red', 'tomato')
        print('--> test case:')
        map1.printMapping(None)
        print('the super-mappings:')
        for key, superKeys in map1.getSuperMappings().items():
            print('%s -> %s' % (key, str(superKeys)))
        
# expand using synonym lists from Moby Thesaurus
def expandByMoby(keyToWrds):
    """
    Expand "keyToWrds" using the Moby Thesaurus ('mobythesaurus.txt').

    Each Moby line is 'coreword,syn,syn,...'. Lines whose synonyms are
    too common in the core vocabulary are skipped; otherwise the core
    word is added to every existing mapping already containing at least
    3 of the line's synonyms.
    NOTE(review): terms are not strip()ed, so the last synonym on each
    line keeps its trailing newline -- confirm before changing.
    """
    wrdToKeys = keyToWrds.invert()
    fin = open("mobythesaurus.txt", "r")
    nWrdsAdded = 0
    lno = 0
    while True:
        li = fin.readline()
        if len(li) == 0:
            break
        lno += 1
        # progress indicator: the thesaurus is large
        if lno % 2000 == 0:
            print("expandByMoby line %s" % lno)
        synlst = li.split(',')
        mcore = synlst[0]
        # how many syns appear in our core vocab?
        nCore = 0
        for w in synlst:
            if lexicon.get(w) is not None:
                nCore += 1
        if nCore > 3:
            # skip: word is too general
            continue
        # count, per existing key, how many of this line's synonyms map
        # to that key
        keyCnt = {}
        for w in synlst:
            keys = wrdToKeys.dct.get(w)
            if keys is not None:
                for key in keys:
                    keyCnt[key] = keyCnt.get(key, 0) + 1
        addedMobyWrd = False
        for key, cnt in keyCnt.items():
            # require at least 3 words in the moby synlist to appear
            # in the mapping
            if cnt < 3:
                continue
            # add the moby core word to the mapping
            keyToWrds.addMapping(key, mcore)
            addedMobyWrd = True
        if addedMobyWrd:
            nWrdsAdded += 1
    fin.close()
    print('expandByMoby added %d words.' % nWrdsAdded)
                
    
if __name__== "__main__":
##    StrToSetOfStrs.ut()
##    exit(1)
    makeRogetExp()
    print "expanded cross ref's..."
    getLexicon()
    dom = Dom()
    dom.readDom('rogetexp.txt')
    makeSectionLexicons(dom)
    printSectionLexicons('tmp2.txt',dom)
    print "created section lexicons..."
    keyToWrds = StrToSetOfStrs()
    for sName,s in dom.sections.iteritems():
        for g in s.groups:
            # split the group up into synlsts
            synlsts = []
            for sl in g.synlsts.split(';'):
                sl = sl.strip()
                if len(sl) == 0:
                    continue
                terms = set()
                for term in sl.split(','):
                    term = term.strip()
                    if term.startswith('be '):
                        term = term[3:]
                    if len(term) == 0 or ignoreWord(term):
                        continue
                    terms.add(term)
                synlsts.append(terms)
            grpKey = getGroupKey(s.dct,synlsts)
            if grpKey is None:
                continue
            for sl in synlsts:
                if len(sl) <= 1:
                    continue
                synsetKey = getSynsetKey(s.dct,sl)
                if synsetKey is None:
                    continue
                key = '%s.%s' % (synsetKey,grpKey)
                for term in sl:
                    keyToWrds.addMapping(key,term)
           
    # expand synlists using the Moby Thesaurus
    expandByMoby(keyToWrds)               
    # refactor. 5 refactors seems to yield stability.
    for cnt in range(5):
        print '\nrefactor pass %d' % cnt
        keyToWrds.refactor()
    keyToWrds.mergeSimilar()
    keyToWrds.killSubsets()
    # invert the mapping to create wrd->{keys}. Then walk
    # the core vocab: if a core word isn't mapped to at least
    # one key, create a singleton entry for it.
    nSingletons = 0
    wrdToKeys = keyToWrds.invert()
    for sp,wrd in lexicon.iteritems():
        if (wrd.props & PS_Ignore) != 0:
            continue
        if wrdToKeys.dct.get(sp) is None:
            # print 'adding singleton for: %s' % sp
            keyToWrds.addMapping('%s.%s'%(sp,sp),sp)
            nSingletons += 1
    print 'created %d singletons' % nSingletons
    # write out the synlists
    synlsts = []
    for key,aset in keyToWrds.dct.iteritems():
        lst = list(aset)
        lst.sort()
        synlsts.append(','.join(lst))
    synlsts.sort()
    fp = open('synlists.txt','w')
    for lst in synlsts:
        fp.write('%s\n' % lst)
    fp.close()
    print 'created %d entries in "synlists.txt"' % len(synlsts)
    # write mapping wrd->(synlistIx}. Start the enumeration sequence
    # at 1. This is for dev/testing. Also get some stats on terms in
    # synlsts: number of single words vs. number of phrases.
    wrdToLstIx = {}
    i = 1
    for lst in synlsts:
        for w in lst.split(','):
            if wrdToLstIx.get(w) is None:
                wrdToLstIx[w] = []
            wrdToLstIx[w].append(str(i))
        i += 1
    tmp = []
    nSingleWords = 0
    nPhrases = 0
    for w,lst in wrdToLstIx.iteritems():
        tmp.append([w,','.join(lst)])
        if w.find(' ') == -1:
            nSingleWords += 1
        else:
            nPhrases += 1
    tmp.sort()
    fp = open('wordtolsts.txt','w')
    for w,lst in tmp:
        fp.write("%s:%s\n"%(w,lst))
    fp.close()
    print 'Synlst vocabulary: %d (%d words, %d phrases)' % \
          (nSingleWords+nPhrases,nSingleWords,nPhrases)
    printLexiconStats()

        
        
                
                
        
        

    
    
